#include <gtest/gtest.h>

#include <AH/Filters/EMA.hpp>
#include <algorithm>
#include <array>
#include <cmath>

using AH::EMA;
using AH::EMA_f;

TEST(EMA, EMA) {
    // The reference sequence was generated by SciPy using floats, with only
    // the final results rounded. The EMA filter under test uses integers for
    // its intermediate results, so rounding errors accumulate and the two
    // sequences are not guaranteed to be exactly equal in general. For this
    // particular input, however, they turn out to match exactly.
    using namespace std;
    const array<uint16_t, 12> input = {
        100, 100, 25, 25, 50, 123, 465, 75, 56, 50, 23, 41,
    };
    const array<uint16_t, 12> expected = {
        25, 44, 39, 36, 39, 60, 161, 140, 119, 102, 82, 72,
    };

    EMA<2, uint16_t> filter;
    array<uint16_t, 12> filtered{};
    // Feed the samples one by one, in order, and record each output.
    for (size_t k = 0; k < input.size(); ++k)
        filtered[k] = filter(input[k]);

    EXPECT_EQ(filtered, expected);
}

TEST(EMA, EMA_overflow) {
    // Drive the filter with the maximum allowed input value on every step and
    // check that nothing overflows: each output has to stay within one count
    // of the analytically computed reference.
    using namespace std;
    constexpr size_t N = 2000;
    constexpr uint16_t maximum = (1 << 10) - 1;

    // Reference for a constant input: maximum * (1 - 0.984375^n), n = 1..N,
    // rounded to the nearest integer.
    array<uint16_t, N> expected;
    int n = 0;
    for (uint16_t &reference : expected) {
        ++n;
        reference = round(maximum - maximum * pow(1 - 0.015625, n));
    }

    EMA<6, uint16_t> ema;
    for (size_t k = 0; k < N; ++k) {
        const uint16_t output = ema(maximum);
        EXPECT_NEAR(output, expected[k], 1);
    }
}

TEST(EMA, EMA_f) {
    // Floating point variant of the filter, constructed from 0.75. Each
    // sample of the input is replaced in place by the filter output and then
    // compared against the reference sequence using gtest's float comparison
    // (a few ULPs of tolerance) rather than exact equality.
    using namespace std;
    EMA_f ema = 0.75;
    array<float, 12> signal = {
        100.0, 100.0, 25.0, 25.0, 50.0, 123.0,
        465.0, 75.0,  56.0, 50.0, 23.0, 41.0,
    };
    array<float, 12> expected = {
        25.0,      43.75,     39.0625,  35.546875, 39.160156, 60.120117,
        161.34009, 139.75507, 118.8163, 101.61223, 81.95917,  71.719376,
    };
    for_each(signal.begin(), signal.end(), [&](float &s) { s = ema(s); });
    // Non-fatal EXPECT_ with the sample index attached: a single run reports
    // every mismatching sample (and where it is) instead of aborting the test
    // at the first one.
    for (size_t i = 0; i < signal.size(); ++i) {
        EXPECT_FLOAT_EQ(signal[i], expected[i]) << "at sample index " << i;
    }
}

TEST(EMA, EMA_overflow_init) {
    // Construct the filter with the largest 16-bit value (65535) as its
    // constructor argument, then keep feeding that same value: every output
    // is expected to equal it exactly.
    using namespace std;
    constexpr uint16_t maximum = 0xFFFF;
    EMA<6, uint16_t, uint32_t> ema(maximum);
    for (int repetition = 0; repetition < 3; ++repetition)
        EXPECT_EQ(ema(maximum), maximum);
}

TEST(EMA, EMA_overflow_reset) {
    using namespace std;
    constexpr uint16_t maximum = 0xFFFF;
    EMA<6, uint16_t, uint32_t> ema;

    // A default-constructed filter fed with zeros is expected to output zeros.
    for (int repetition = 0; repetition < 3; ++repetition)
        EXPECT_EQ(ema(0), 0);

    // After reset(maximum), feeding the largest 16-bit value (65535) is
    // expected to return exactly that value every time.
    ema.reset(maximum);
    for (int repetition = 0; repetition < 3; ++repetition)
        EXPECT_EQ(ema(maximum), maximum);
}
